from torch import nn
from torchvision import models
import os

from .Masking import masking
from .Network import Amend


class ResNet18ARMMasking(nn.Module):
    """ResNet-18 backbone combined with per-stage masking branches and an
    ARM ("Amend") head.

    ``forward`` returns ``(out, alpha)`` where ``out`` is
    ``[batch_size, num_classes]`` logits and ``alpha`` is the auxiliary
    value produced by the ARM module.
    """

    def __init__(self, pretrained=True, num_classes=7, drop_rate=0, feed_huawei=None):
        """
        Args:
            pretrained (bool): load ImageNet weights for the ResNet-18
                backbone. Default ``True``.
            num_classes (int): number of output classes. Default 7.
            drop_rate (float): dropout probability applied to the ARM output
                before the final fc layer; 0 disables dropout.
            feed_huawei (str | None): optional directory assigned to
                ``TORCH_HOME`` so pretrained weights are cached/downloaded
                there (must be set before the backbone is built).
        """
        super(ResNet18ARMMasking, self).__init__()
        self.drop_rate = drop_rate

        if feed_huawei is not None:
            # Redirect torch's weight cache before torchvision downloads.
            os.environ['TORCH_HOME'] = feed_huawei

        # BUG FIX: `pretrained` was previously ignored (hard-coded False),
        # which made the TORCH_HOME redirection above dead code and silently
        # gave callers a randomly initialized backbone.
        self.resnet = models.resnet18(pretrained=pretrained)

        # ARM head: PixelShuffle(16) trades channels for resolution
        # (512x14x14 -> 2x224x224); Amend reduces it to a 1x25x25 map that
        # is flattened into the 625-dim input of the classifier.
        # NOTE(review): `self.features` is not used in forward(); kept for
        # backward compatibility (external callers / state_dict layout).
        self.features = nn.Sequential(*list(self.resnet.children())[:-2])
        self.arrangement = nn.PixelShuffle(16)
        self.arm = Amend()
        self.fc = nn.Linear(625, num_classes)

        # BUG FIX: dropout is registered once here instead of being built
        # inside forward(). A module constructed per-call is always in
        # training mode, so dropout used to fire even after model.eval();
        # registering it lets train()/eval() propagate correctly.
        self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()

        # Masking branches: stem (conv1+bn1+relu) plus one masking module
        # per residual stage; each mask is applied as a residual gate.
        self.block1 = nn.Sequential(*list(self.resnet.children())[:3])
        self.mask1 = masking(64, 64, depth=4)
        self.mask2 = masking(128, 128, depth=3)
        self.mask3 = masking(256, 256, depth=2)
        self.mask4 = masking(512, 512, depth=1)

    def forward(self, x):
        """Run the masked backbone and ARM head.

        Args:
            x: input batch, assumed ``[batch_size, 3, 224, 224]`` — the
               spatial comments below presume this size (TODO confirm).

        Returns:
            tuple: ``(out, alpha)`` with ``out`` of shape
            ``[batch_size, num_classes]``.
        """
        x = self.block1(x)  # [batch_size, 64, 112, 112]

        # Each stage output is modulated by its mask branch with a
        # residual gate: x * (1 + m) keeps the identity path intact.
        x = self.resnet.layer1(x)  # [batch_size, 64, 112, 112]
        x = x * (1 + self.mask1(x))

        x = self.resnet.layer2(x)  # [batch_size, 128, 56, 56]
        x = x * (1 + self.mask2(x))

        x = self.resnet.layer3(x)  # [batch_size, 256, 28, 28]
        x = x * (1 + self.mask3(x))

        x = self.resnet.layer4(x)  # [batch_size, 512, 14, 14]
        x = x * (1 + self.mask4(x))

        x = self.arrangement(x)  # [batch_size, 2, 224, 224]
        x, alpha = self.arm(x)  # [batch_size, 1, 25, 25]

        # No-op (Identity) when drop_rate == 0; disabled by model.eval().
        x = self.dropout(x)

        x = x.view(x.size(0), -1)  # [batch_size, 625]
        out = self.fc(x)

        return out, alpha
